!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Routines for the construction of the coefficients
!>      for the expansion  of the atomic
!>      densities rho1_hard and rho1_soft in terms of primitive spherical gaussians.
!> \par History
!>      05-2004 created
!> \author MI
! **************************************************************************************************
MODULE qs_oce_methods

   USE ai_overlap,                      ONLY: overlap
   USE ao_util,                         ONLY: exp_radius
   USE basis_set_types,                 ONLY: get_gto_basis_set,&
                                              gto_basis_set_type
   USE block_p_types,                   ONLY: block_p_type
   USE kinds,                           ONLY: dp
   USE orbital_pointers,                ONLY: init_orbital_pointers,&
                                              nco,&
                                              ncoset,&
                                              nso
   USE orbital_transformation_matrices, ONLY: orbtramat
   USE particle_types,                  ONLY: particle_type
   USE paw_basis_types,                 ONLY: get_paw_basis_info
   USE paw_proj_set_types,              ONLY: get_paw_proj_set,&
                                              paw_proj_set_type
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              get_qs_kind_set,&
                                              qs_kind_type
   USE qs_neighbor_list_types,          ONLY: neighbor_list_set_p_type
   USE sap_kind_types,                  ONLY: clist_type,&
                                              sap_int_type,&
                                              sap_sort
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   ! Global parameters (only in this module)

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'qs_oce_methods'

   ! Public subroutines

   PUBLIC :: build_oce_matrices, proj_blk, prj_scatter, prj_gather

CONTAINS

! **************************************************************************************************
!> \brief Adds to oces the overlap between the projectors of atom kind atom_ka and the contracted
!>        orbital basis functions of atom kind atom_kb, for every set of atom_kb that passes the
!>        distance screening. Primitive overlaps are transformed to spherical projector components,
!>        contracted with sphi_b of atom_kb and finally projected with csprj of atom_ka.
!> \param oces accumulator blocks, one per index ider = 1..ncoset(nder) of the primitive overlap
!>        output (integrals and derivatives); columns are filled consecutively, in the same order as
!>        the screened functions stored in sgf_list
!> \param atom_ka qs kind carrying the projectors; nothing is computed unless it is a PAW atom
!> \param atom_kb qs kind carrying the orbital basis
!> \param rab vector between the two centres; -rab is passed to the primitive overlap routine
!> \param nder derivative order passed to the primitive overlap routine
!> \param sgf_list on exit, indices of the contracted functions of atom_kb that passed screening
!> \param nsgf_cnt on exit, number of valid entries in sgf_list (0 if atom_ka is not a PAW atom)
!> \param sgf_soft_only on exit, .TRUE. if none of the screened sets was classified as hard
!> \param eps_fit threshold used with exp_radius to classify a set as hard or soft
! **************************************************************************************************
   SUBROUTINE build_oce_block(oces, atom_ka, atom_kb, rab, nder, sgf_list, nsgf_cnt, sgf_soft_only, &
                              eps_fit)

      TYPE(block_p_type), DIMENSION(:), POINTER          :: oces
      TYPE(qs_kind_type), POINTER                        :: atom_ka, atom_kb
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: rab
      INTEGER, INTENT(IN)                                :: nder
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: sgf_list
      INTEGER, INTENT(OUT)                               :: nsgf_cnt
      LOGICAL, INTENT(OUT)                               :: sgf_soft_only
      REAL(KIND=dp), INTENT(IN)                          :: eps_fit

      INTEGER :: first_col, ic, ider, ig1, igau, ip, ipgf, is, isgfb, isgfb_cnt, isp, jc, jset, &
         lds, lm, lpoint, lprj, lsgfb, lshell, m, m1, maxcob, maxder, maxlb, maxlprj, maxnprja, &
         msab, n, ncob, np_car, np_sph, nsatbas, nsetb, ntotsgfb, sgf_hard_only
      INTEGER, DIMENSION(:), POINTER                     :: fp_cara, fp_spha, lb_max, lb_min, npgfb, &
                                                            nprjla, nsgfb
      INTEGER, DIMENSION(:, :), POINTER                  :: first_sgfb
      LOGICAL                                            :: paw_atom_a, paw_atom_b
      REAL(KIND=dp)                                      :: dab, hard_radius_a, hard_radius_b, &
                                                            radius, rcprja
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: ovs, spa_sb, spa_tmp
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: s
      REAL(KIND=dp), DIMENSION(:), POINTER               :: set_radius_b, zisomina, zisominb
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: csprj, rpgfb, rzetprja, sphi_b, zetb, &
                                                            zetprja
      TYPE(gto_basis_set_type), POINTER                  :: orb_basis_b
      TYPE(paw_proj_set_type), POINTER                   :: paw_proj_a, paw_proj_b

      ! Define the INTENT(OUT) results on every exit path, including the early return below
      nsgf_cnt = 0
      sgf_soft_only = .TRUE.

      NULLIFY (paw_proj_a)
      CALL get_qs_kind(qs_kind=atom_ka, paw_proj_set=paw_proj_a, paw_atom=paw_atom_a, &
                       hard_radius=hard_radius_a)

      NULLIFY (orb_basis_b, paw_proj_b)
      CALL get_qs_kind(qs_kind=atom_kb, basis_set=orb_basis_b, &
                       paw_proj_set=paw_proj_b, paw_atom=paw_atom_b, &
                       hard_radius=hard_radius_b)

      IF (.NOT. paw_atom_a) RETURN

      NULLIFY (nprjla, fp_cara, fp_spha, rzetprja, zetprja)
      CALL get_paw_proj_set(paw_proj_set=paw_proj_a, csprj=csprj, maxl=maxlprj, &
                            nprj=nprjla, ncgauprj=np_car, nsgauprj=np_sph, nsatbas=nsatbas, rcprj=rcprja, &
                            first_prj=fp_cara, first_prjs=fp_spha, &
                            zisomin=zisomina, rzetprj=rzetprja, zetprj=zetprja)

      NULLIFY (first_sgfb, lb_max, lb_min, npgfb, nsgfb, rpgfb, sphi_b, set_radius_b, zetb, zisominb)
      CALL get_gto_basis_set(gto_basis_set=orb_basis_b, nset=nsetb, nsgf=ntotsgfb, &
                             set_radius=set_radius_b, lmax=lb_max, lmin=lb_min, &
                             npgf=npgfb, nsgf_set=nsgfb, pgf_radius=rpgfb, &
                             sphi=sphi_b, zet=zetb, first_sgf=first_sgfb, &
                             maxco=maxcob, maxl=maxlb)

      ! Per-l exponent thresholds of atom_kb's projector set; loop invariant, so fetched once
      IF (paw_atom_b) CALL get_paw_proj_set(paw_proj_set=paw_proj_b, zisomin=zisominb)

      dab = SQRT(SUM(rab*rab))

      maxder = ncoset(nder)
      maxnprja = SIZE(zetprja, 1)

      lm = MAX(maxlb, maxlprj)
      lds = ncoset(lm + nder + 1)
      msab = MAX(maxnprja*ncoset(maxlprj), maxcob)

      ALLOCATE (s(lds, lds, ncoset(nder + 1)))
      ALLOCATE (spa_sb(np_car, ntotsgfb))
      ALLOCATE (spa_tmp(msab, msab*maxder))
      ALLOCATE (ovs(np_sph, maxcob*nsetb*maxder))

      m1 = 0
      isgfb_cnt = 1
      sgf_hard_only = 0
      DO jset = 1, nsetb
         ! Screening: only sets of atom_kb that can reach the hard sphere of atom_ka contribute
         IF (hard_radius_a + set_radius_b(jset) >= dab) THEN
            isgfb = first_sgfb(1, jset)
            lsgfb = isgfb - 1 + nsgfb(jset)
            DO jc = isgfb, lsgfb
               nsgf_cnt = nsgf_cnt + 1
               sgf_list(nsgf_cnt) = jc
            END DO

            ! The set counts as hard if its steepest primitive stays inside hard_radius_b
            radius = exp_radius(lb_max(jset), MAXVAL(zetb(1:npgfb(jset), jset)), eps_fit, 1.0_dp)
            IF (radius <= hard_radius_b) sgf_hard_only = sgf_hard_only + 1

            ! Integral between the projectors of atom_ka and the primitives of atom_kb
            spa_tmp = 0.0_dp
            ovs = 0.0_dp
            s = 0.0_dp
            ncob = npgfb(jset)*ncoset(lb_max(jset))
            ! Number of Cartesian components with l in [lb_min, lb_max] of one primitive
            n = ncoset(lb_max(jset)) - ncoset(lb_min(jset) - 1)

            DO lprj = 0, maxlprj
               CALL overlap(lprj, lprj, nprjla(lprj), &
                            rzetprja(:, lprj), zetprja(:, lprj), &
                            lb_max(jset), lb_min(jset), npgfb(jset), &
                            rpgfb(:, jset), zetb(:, jset), &
                            -rab, dab, spa_tmp, &
                            nder, .TRUE., s, lds)
               DO ider = 1, maxder
                  ! Column offset of derivative ider in spa_tmp and in ovs
                  is = (ider - 1)*SIZE(spa_tmp, 1)
                  isp = (ider - 1)*maxcob*nsetb
                  DO ipgf = 1, nprjla(lprj)
                     lpoint = ncoset(lprj - 1) + 1 + (ipgf - 1)*ncoset(lprj)
                     m = fp_spha(lprj) + (ipgf - 1)*nso(lprj)
                     DO ip = 1, npgfb(jset)
                        ! Each primitive of atom_kb occupies ncoset(lb_max) columns
                        ic = (ip - 1)*ncoset(lb_max(jset))
                        igau = isp + ic + m1 + ncoset(lb_min(jset) - 1) + 1
                        ig1 = is + ic + ncoset(lb_min(jset) - 1) + 1
                        ! Cartesian -> spherical transformation on the projector side
                        ovs(m:m + nso(lprj) - 1, igau:igau + n - 1) = &
                           MATMUL(orbtramat(lprj)%slm(1:nso(lprj), 1:nco(lprj)), &
                                  spa_tmp(lpoint:lpoint + nco(lprj) - 1, ig1:ig1 + n - 1))
                     END DO
                  END DO
               END DO
            END DO

            ! Remove the primitives of atom_kb whose exponent reaches zisomin(l) of its projector set
            IF (paw_atom_b) THEN
               DO ipgf = 1, npgfb(jset)
                  ! Same per-primitive stride (ncoset(lb_max)) used above when filling ovs and by
                  ! the rows of sphi_b; using n here would misplace columns whenever lb_min > 0
                  ic = (ipgf - 1)*ncoset(lb_max(jset))
                  DO lshell = lb_min(jset), lb_max(jset)
                     IF (zetb(ipgf, jset) >= zisominb(lshell)) THEN
                        igau = ic + ncoset(lshell - 1)
                        DO ider = 1, maxder
                           isp = (ider - 1)*maxcob*nsetb
                           ovs(:, igau + 1 + isp + m1:igau + nco(lshell) + isp + m1) = 0.0_dp
                        END DO
                     END IF
                  END DO
               END DO
            END IF

            ! Contraction step (integrals and derivatives); BLAS equivalent: two dgemm calls
            DO ider = 1, maxder
               first_col = (ider - 1)*maxcob*nsetb + 1 + m1
               spa_sb(1:np_sph, isgfb:lsgfb) = &
                  MATMUL(ovs(1:np_sph, first_col:first_col + ncob - 1), &
                         sphi_b(1:ncob, isgfb:lsgfb))

               oces(ider)%block(1:nsatbas, isgfb_cnt:isgfb_cnt + nsgfb(jset) - 1) = &
                  oces(ider)%block(1:nsatbas, isgfb_cnt:isgfb_cnt + nsgfb(jset) - 1) + &
                  MATMUL(TRANSPOSE(csprj(1:np_sph, 1:nsatbas)), &
                         spa_sb(1:np_sph, isgfb:lsgfb))
            END DO
            isgfb_cnt = isgfb_cnt + nsgfb(jset)
         END IF
         m1 = m1 + maxcob
      END DO

      ! Soft only when no screened set was classified as hard
      sgf_soft_only = (sgf_hard_only == 0)

      DEALLOCATE (s, spa_sb, spa_tmp, ovs)

   END SUBROUTINE build_oce_block

! **************************************************************************************************
!> \brief Fills the one-centre expansion blocks of atom kind atom_ka from the precomputed
!>        local_oce_sphi_h / local_oce_sphi_s matrices of its PAW projector set (the case where
!>        projector centre and basis centre are the same atom in the same cell).
!> \param oceh on exit, oceh(1)%block(1:nsatbas, 1:nsgf) holds the rows of the padded hard local
!>        matrix selected by n2oindex of the GAPW_1C basis
!> \param oces on exit, oces(1)%block(1:nsatbas, 1:nsgf) holds the corresponding soft rows
!> \param atom_ka qs kind of the atom; nothing is computed unless it is a PAW atom
!> \param sgf_list on exit, indices of all contracted functions of the orbital basis, set by set
!> \param nsgf_cnt on exit, number of valid entries in sgf_list (0 if atom_ka is not a PAW atom)
! **************************************************************************************************
   SUBROUTINE build_oce_block_local(oceh, oces, atom_ka, sgf_list, nsgf_cnt)

      TYPE(block_p_type), DIMENSION(:), POINTER          :: oceh, oces
      TYPE(qs_kind_type), POINTER                        :: atom_ka
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: sgf_list
      INTEGER, INTENT(OUT)                               :: nsgf_cnt

      INTEGER                                            :: iset, isgfa, jc, lsgfa, maxso1a, n, &
                                                            nsatbas, nset1a, nseta, nsgfa
      INTEGER, DIMENSION(:), POINTER                     :: n2oindex, nsgf_seta
      INTEGER, DIMENSION(:, :), POINTER                  :: first_sgfa
      LOGICAL                                            :: paw_atom_a
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: prjloc_h, prjloc_s
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: local_oce_h, local_oce_s
      TYPE(gto_basis_set_type), POINTER                  :: basis_1c_a, orb_basis_a
      TYPE(paw_proj_set_type), POINTER                   :: paw_proj_a

      ! Define the INTENT(OUT) count on every exit path, including the early return below
      nsgf_cnt = 0

      NULLIFY (orb_basis_a, basis_1c_a, paw_proj_a)
      CALL get_qs_kind(qs_kind=atom_ka, basis_set=orb_basis_a, &
                       paw_proj_set=paw_proj_a, paw_atom=paw_atom_a)

      IF (.NOT. paw_atom_a) RETURN

      CALL get_qs_kind(qs_kind=atom_ka, basis_set=basis_1c_a, basis_type="GAPW_1C")
      CALL get_gto_basis_set(gto_basis_set=basis_1c_a, nset=nset1a, maxso=maxso1a)
      NULLIFY (n2oindex)
      CALL get_paw_basis_info(basis_1c_a, n2oindex=n2oindex, nsatbas=nsatbas)

      CALL get_gto_basis_set(gto_basis_set=orb_basis_a, first_sgf=first_sgfa, &
                             nsgf=nsgfa, nsgf_set=nsgf_seta, nset=nseta)

      NULLIFY (local_oce_h, local_oce_s)
      CALL get_paw_proj_set(paw_proj_set=paw_proj_a, &
                            local_oce_sphi_h=local_oce_h, &
                            local_oce_sphi_s=local_oce_s)

      ! Rows are padded to maxso1a per set of the GAPW_1C basis
      ALLOCATE (prjloc_h(nset1a*maxso1a, nsgfa), prjloc_s(nset1a*maxso1a, nsgfa))
      prjloc_h = 0.0_dp
      prjloc_s = 0.0_dp

      DO iset = 1, nseta
         isgfa = first_sgfa(1, iset)
         lsgfa = isgfa - 1 + nsgf_seta(iset)
         DO jc = isgfa, lsgfa
            nsgf_cnt = nsgf_cnt + 1
            sgf_list(nsgf_cnt) = jc
         END DO
         ! This assumes that the first sets are the same for basis_1c/orb_basis!
         n = maxso1a*(iset - 1)
         prjloc_h(n + 1:n + maxso1a, isgfa:lsgfa) = local_oce_h(1:maxso1a, isgfa:lsgfa)
         prjloc_s(n + 1:n + maxso1a, isgfa:lsgfa) = local_oce_s(1:maxso1a, isgfa:lsgfa)
      END DO

      ! Row j of the output blocks is row n2oindex(j) of the padded local matrices
      oceh(1)%block(1:nsatbas, 1:nsgfa) = prjloc_h(n2oindex(1:nsatbas), 1:nsgfa)
      oces(1)%block(1:nsatbas, 1:nsgfa) = prjloc_s(n2oindex(1:nsatbas), 1:nsgfa)

      DEALLOCATE (prjloc_h, prjloc_s)
      DEALLOCATE (n2oindex)

   END SUBROUTINE build_oce_block_local

! **************************************************************************************************
!> \brief Adds to oces/oceh the one-centre overlap between the projectors of atom kind atom_ka and
!>        its own contracted orbital basis functions (both on the same centre, rab = 0).
!>        For oces, primitives whose exp_radius is below the hard radius are removed before
!>        contraction and the result is projected with csprj; for oceh all primitives are kept and
!>        the result is projected with chprj.
!> \param oceh accumulator block oceh(1); columns follow the order of sgf_list
!> \param oces accumulator block oces(1); columns follow the order of sgf_list
!> \param atom_ka qs kind of the atom; nothing is computed unless it is a PAW atom
!> \param sgf_list on exit, indices of all contracted functions of the orbital basis, set by set
!> \param nsgf_cnt on exit, number of valid entries in sgf_list (0 if atom_ka is not a PAW atom)
!> \param eps_fit threshold used with exp_radius to decide which primitives are removed from oces
! **************************************************************************************************
   SUBROUTINE build_oce_block_1c(oceh, oces, atom_ka, sgf_list, nsgf_cnt, eps_fit)

      TYPE(block_p_type), DIMENSION(:), POINTER          :: oceh, oces
      TYPE(qs_kind_type), POINTER                        :: atom_ka
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: sgf_list
      INTEGER, INTENT(OUT)                               :: nsgf_cnt
      REAL(KIND=dp), INTENT(IN)                          :: eps_fit

      INTEGER :: first_col, ic, ig1, igau, ip, ipgf, isgfb, isgfb_cnt, jc, jset, lds, lm, lpoint, &
         lprj, lsgfb, lshell, m, m1, maxcob, maxlb, maxlprj, maxnprja, msab, n, ncob, np_car, &
         np_sph, nsatbas, nsetb, ntotsgfb
      INTEGER, DIMENSION(:), POINTER                     :: fp_cara, fp_spha, lb_max, lb_min, npgfb, &
                                                            nprjla, nsgfb
      INTEGER, DIMENSION(:, :), POINTER                  :: first_sgfb
      LOGICAL                                            :: paw_atom_a
      REAL(KIND=dp)                                      :: radius, rc, rcprja
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: ovh, ovs, spa_sb, spa_tmp
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: s
      REAL(KIND=dp), DIMENSION(:), POINTER               :: set_radius_b
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: chprj, csprj, rpgfb, rzetprja, sphi_b, &
                                                            zetb, zetprja
      TYPE(gto_basis_set_type), POINTER                  :: orb_basis_b
      TYPE(paw_proj_set_type), POINTER                   :: paw_proj_a

      ! Define the INTENT(OUT) count on every exit path, including the early return below
      nsgf_cnt = 0

      NULLIFY (orb_basis_b, paw_proj_a)
      CALL get_qs_kind(qs_kind=atom_ka, paw_atom=paw_atom_a)

      IF (.NOT. paw_atom_a) RETURN

      CALL get_qs_kind(qs_kind=atom_ka, paw_proj_set=paw_proj_a)
      NULLIFY (nprjla, fp_cara, fp_spha, rzetprja, zetprja)
      CALL get_paw_proj_set(paw_proj_set=paw_proj_a, csprj=csprj, chprj=chprj, maxl=maxlprj, &
                            nprj=nprjla, ncgauprj=np_car, nsgauprj=np_sph, nsatbas=nsatbas, rcprj=rcprja, &
                            first_prj=fp_cara, first_prjs=fp_spha, &
                            rzetprj=rzetprja, zetprj=zetprja)

      CALL get_qs_kind(qs_kind=atom_ka, basis_set=orb_basis_b)
      NULLIFY (first_sgfb, lb_max, lb_min, npgfb, nsgfb, rpgfb, sphi_b, set_radius_b, zetb)
      CALL get_gto_basis_set(gto_basis_set=orb_basis_b, nset=nsetb, nsgf=ntotsgfb, &
                             set_radius=set_radius_b, lmax=lb_max, lmin=lb_min, &
                             npgf=npgfb, nsgf_set=nsgfb, pgf_radius=rpgfb, &
                             sphi=sphi_b, zet=zetb, first_sgf=first_sgfb, &
                             maxco=maxcob, maxl=maxlb)

      CALL get_qs_kind(qs_kind=atom_ka, hard_radius=rc)

      maxnprja = SIZE(zetprja, 1)

      lm = MAX(maxlb, maxlprj)
      lds = ncoset(lm + 1)
      msab = MAX(maxnprja*ncoset(maxlprj), maxcob)

      ALLOCATE (s(lds, lds, 1))
      ALLOCATE (spa_sb(np_car, ntotsgfb))
      ALLOCATE (spa_tmp(msab, msab))
      ALLOCATE (ovs(np_sph, maxcob*nsetb))
      ALLOCATE (ovh(np_sph, maxcob*nsetb))

      m1 = 0
      isgfb_cnt = 1
      DO jset = 1, nsetb
         ! Every set contributes (same centre, no distance screening)
         isgfb = first_sgfb(1, jset)
         lsgfb = isgfb - 1 + nsgfb(jset)
         DO jc = isgfb, lsgfb
            nsgf_cnt = nsgf_cnt + 1
            sgf_list(nsgf_cnt) = jc
         END DO

         ! Integral between the projectors and the primitives of the same atom
         spa_tmp = 0.0_dp
         ovs = 0.0_dp
         s = 0.0_dp
         ncob = npgfb(jset)*ncoset(lb_max(jset))
         ! Number of Cartesian components with l in [lb_min, lb_max] of one primitive
         n = ncoset(lb_max(jset)) - ncoset(lb_min(jset) - 1)

         DO lprj = 0, maxlprj
            CALL overlap(lprj, lprj, nprjla(lprj), &
                         rzetprja(:, lprj), zetprja(:, lprj), &
                         lb_max(jset), lb_min(jset), npgfb(jset), &
                         rpgfb(:, jset), zetb(:, jset), &
                         -(/0._dp, 0._dp, 0._dp/), 0.0_dp, spa_tmp, &
                         0, .TRUE., s, lds)
            DO ipgf = 1, nprjla(lprj)
               lpoint = ncoset(lprj - 1) + 1 + (ipgf - 1)*ncoset(lprj)
               m = fp_spha(lprj) + (ipgf - 1)*nso(lprj)
               DO ip = 1, npgfb(jset)
                  ! Each primitive occupies ncoset(lb_max) columns
                  ic = (ip - 1)*ncoset(lb_max(jset))
                  igau = ic + m1 + ncoset(lb_min(jset) - 1) + 1
                  ig1 = ic + ncoset(lb_min(jset) - 1) + 1
                  ! Cartesian -> spherical transformation on the projector side
                  ovs(m:m + nso(lprj) - 1, igau:igau + n - 1) = &
                     MATMUL(orbtramat(lprj)%slm(1:nso(lprj), 1:nco(lprj)), &
                            spa_tmp(lpoint:lpoint + nco(lprj) - 1, ig1:ig1 + n - 1))
               END DO
            END DO
         END DO

         ! oceh keeps every primitive
         ovh(:, :) = ovs(:, :)

         ! oces drops the primitives whose exp_radius lies inside the hard radius rc
         DO ipgf = 1, npgfb(jset)
            ! Same per-primitive stride (ncoset(lb_max)) used above when filling ovs and by the rows
            ! of sphi_b; using n here would misplace columns whenever lb_min > 0
            ic = (ipgf - 1)*ncoset(lb_max(jset))
            DO lshell = lb_min(jset), lb_max(jset)
               radius = exp_radius(lshell, zetb(ipgf, jset), eps_fit, 1.0_dp)
               IF (radius < rc) THEN
                  igau = ic + ncoset(lshell - 1)
                  ovs(:, igau + 1 + m1:igau + nco(lshell) + m1) = 0.0_dp
               END IF
            END DO
         END DO

         ! Contraction step
         first_col = 1 + m1
         spa_sb(1:np_sph, isgfb:lsgfb) = &
            MATMUL(ovs(1:np_sph, first_col:first_col + ncob - 1), &
                   sphi_b(1:ncob, isgfb:lsgfb))

         oces(1)%block(1:nsatbas, isgfb_cnt:isgfb_cnt + nsgfb(jset) - 1) = &
            oces(1)%block(1:nsatbas, isgfb_cnt:isgfb_cnt + nsgfb(jset) - 1) + &
            MATMUL(TRANSPOSE(csprj(1:np_sph, 1:nsatbas)), &
                   spa_sb(1:np_sph, isgfb:lsgfb))

         spa_sb(1:np_sph, isgfb:lsgfb) = &
            MATMUL(ovh(1:np_sph, first_col:first_col + ncob - 1), &
                   sphi_b(1:ncob, isgfb:lsgfb))

         oceh(1)%block(1:nsatbas, isgfb_cnt:isgfb_cnt + nsgfb(jset) - 1) = &
            oceh(1)%block(1:nsatbas, isgfb_cnt:isgfb_cnt + nsgfb(jset) - 1) + &
            MATMUL(TRANSPOSE(chprj(1:np_sph, 1:nsatbas)), &
                   spa_sb(1:np_sph, isgfb:lsgfb))

         isgfb_cnt = isgfb_cnt + nsgfb(jset)
         m1 = m1 + maxcob
      END DO

      DEALLOCATE (s, spa_sb, spa_tmp, ovs, ovh)

   END SUBROUTINE build_oce_block_1c

! **************************************************************************************************
!> \brief Set up the sparse matrix for the coefficients of one center expansions
!>      This routine uses the same logic as the nonlocal pseudopotential
!> \param intac TYPE that holds the integrals (a=basis; c=projector)
!> \param calculate_forces ...
!> \param nder ...
!> \param qs_kind_set ...
!> \param particle_set ...
!> \param sap_oce ...
!> \param eps_fit ...
!> \par History
!>      02.2009 created
!> \author jgh
! **************************************************************************************************
   SUBROUTINE build_oce_matrices(intac, calculate_forces, nder, &
                                 qs_kind_set, particle_set, sap_oce, eps_fit)

      TYPE(sap_int_type), DIMENSION(:), POINTER          :: intac
      LOGICAL, INTENT(IN)                                :: calculate_forces
      INTEGER                                            :: nder
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sap_oce
      REAL(KIND=dp), INTENT(IN)                          :: eps_fit

      CHARACTER(LEN=*), PARAMETER :: routineN = 'build_oce_matrices'

      INTEGER :: atom_a, atom_b, handle, i, iac, ikind, ilist, jkind, jneighbor, ldai, ldsab, &
         maxco, maxder, maxl, maxlgto, maxlprj, maxprj, maxsgf, maxsoa, maxsob, mlprj, natom, &
         ncoa_sum, nkind, nlist, nneighbor, nsatbas, nseta, nsetb, nsgf_cnt, nsgfa, nsobtot, slot
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: sgf_list
      INTEGER, DIMENSION(3)                              :: cell_b
      INTEGER, DIMENSION(:), POINTER                     :: fp_car, fp_sph, la_max, la_min, npgfa, &
                                                            nprjla, nsgf_seta
      INTEGER, DIMENSION(:, :), POINTER                  :: first_sgfa
      LOGICAL                                            :: local, paw_atom_b, sgf_soft_only
      REAL(KIND=dp)                                      :: dab, rcprj
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sab, work
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: ai_work
      REAL(KIND=dp), DIMENSION(3)                        :: rab
      REAL(KIND=dp), DIMENSION(:), POINTER               :: set_radius_a
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: rpgfa, rzetprj, sphi_a, zeta, zetb
      TYPE(block_p_type), DIMENSION(:), POINTER          :: oceh, oces
      TYPE(clist_type), POINTER                          :: clist
      TYPE(gto_basis_set_type), POINTER                  :: orb_basis_paw, orb_basis_set
      TYPE(paw_proj_set_type), POINTER                   :: paw_proj_b
      TYPE(qs_kind_type), POINTER                        :: at_a, at_b, qs_kind

      IF (calculate_forces) THEN
         CALL timeset(routineN//"_forces", handle)
      ELSE
         CALL timeset(routineN, handle)
      END IF

      IF (ASSOCIATED(sap_oce)) THEN

         nkind = SIZE(qs_kind_set)
         natom = SIZE(particle_set)

         maxder = ncoset(nder)

         CALL get_qs_kind_set(qs_kind_set=qs_kind_set, &
                              maxco=maxco, &
                              maxlgto=maxlgto, &
                              maxlprj=maxlprj, &
                              maxco_proj=maxprj, &
                              maxsgf=maxsgf)

         maxl = MAX(maxlgto, maxlprj)
         CALL init_orbital_pointers(maxl + nder + 1)

         DO i = 1, nkind*nkind
            NULLIFY (intac(i)%alist, intac(i)%asort, intac(i)%aindex)
            intac(i)%nalist = 0
         END DO

         ! Allocate memory list
         DO slot = 1, sap_oce(1)%nl_size
            ikind = sap_oce(1)%nlist_task(slot)%ikind
            jkind = sap_oce(1)%nlist_task(slot)%jkind
            atom_a = sap_oce(1)%nlist_task(slot)%iatom
            nlist = sap_oce(1)%nlist_task(slot)%nlist
            ilist = sap_oce(1)%nlist_task(slot)%ilist
            nneighbor = sap_oce(1)%nlist_task(slot)%nnode

            iac = ikind + nkind*(jkind - 1)

            qs_kind => qs_kind_set(ikind)
            CALL get_qs_kind(qs_kind=qs_kind, basis_set=orb_basis_set)
            IF (.NOT. ASSOCIATED(orb_basis_set)) CYCLE
            qs_kind => qs_kind_set(jkind)
            NULLIFY (paw_proj_b)
            CALL get_qs_kind(qs_kind=qs_kind, paw_proj_set=paw_proj_b, paw_atom=paw_atom_b)
            IF (.NOT. paw_atom_b) CYCLE
            CALL get_qs_kind(qs_kind=qs_kind, basis_set=orb_basis_paw, basis_type="GAPW_1C")
            IF (.NOT. ASSOCIATED(orb_basis_paw)) CYCLE
            IF (.NOT. ASSOCIATED(intac(iac)%alist)) THEN
               intac(iac)%a_kind = ikind
               intac(iac)%p_kind = jkind
               intac(iac)%nalist = nlist
               ALLOCATE (intac(iac)%alist(nlist))
               DO i = 1, nlist
                  NULLIFY (intac(iac)%alist(i)%clist)
                  intac(iac)%alist(i)%aatom = 0
                  intac(iac)%alist(i)%nclist = 0
               END DO
            END IF
            IF (.NOT. ASSOCIATED(intac(iac)%alist(ilist)%clist)) THEN
               intac(iac)%alist(ilist)%aatom = atom_a
               intac(iac)%alist(ilist)%nclist = nneighbor
               ALLOCATE (intac(iac)%alist(ilist)%clist(nneighbor))
            END IF
         END DO

         ldsab = MAX(maxco, ncoset(maxlprj), maxsgf, maxprj)
         ldai = ncoset(maxl + nder + 1)

         !calculate the overlap integrals <a|p>
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED (intac, ldsab, ldai, nder, nkind, maxder, ncoset, sap_oce, qs_kind_set, eps_fit) &
!$OMP PRIVATE (sab, work, ai_work, oceh, oces, slot, ikind, jkind, atom_a, atom_b, ilist, jneighbor, rab, cell_b, &
!$OMP          iac, dab, qs_kind, orb_basis_set, first_sgfa, la_max, la_min, ncoa_sum, maxsoa, npgfa, nseta, &
!$OMP          nsgfa, nsgf_seta, rpgfa, set_radius_a, sphi_a, zeta, paw_proj_b, paw_atom_b, orb_basis_paw, &
!$OMP          maxsob, nsetb, mlprj, nprjla, nsatbas, rcprj, fp_car, fp_sph, rzetprj, zetb, nsobtot, clist, &
!$OMP          sgf_list, at_a, at_b, local, i, sgf_soft_only, nsgf_cnt)

         ALLOCATE (sab(ldsab, ldsab*maxder), work(ldsab, ldsab*maxder))
         sab = 0.0_dp
         ALLOCATE (ai_work(ldai, ldai, ncoset(nder + 1)))
         ai_work = 0.0_dp
         ALLOCATE (oceh(maxder), oces(maxder))

!$OMP DO SCHEDULE(GUIDED)
         DO slot = 1, sap_oce(1)%nl_size
            ikind = sap_oce(1)%nlist_task(slot)%ikind
            jkind = sap_oce(1)%nlist_task(slot)%jkind
            atom_a = sap_oce(1)%nlist_task(slot)%iatom
            atom_b = sap_oce(1)%nlist_task(slot)%jatom
            ilist = sap_oce(1)%nlist_task(slot)%ilist
            jneighbor = sap_oce(1)%nlist_task(slot)%inode
            rab(1:3) = sap_oce(1)%nlist_task(slot)%r(1:3)
            cell_b(1:3) = sap_oce(1)%nlist_task(slot)%cell(1:3)

            iac = ikind + nkind*(jkind - 1)
            dab = SQRT(SUM(rab*rab))

            qs_kind => qs_kind_set(ikind)
            CALL get_qs_kind(qs_kind=qs_kind, basis_set=orb_basis_set)

            IF (.NOT. ASSOCIATED(orb_basis_set)) CYCLE
            CALL get_gto_basis_set(gto_basis_set=orb_basis_set, &
                                   first_sgf=first_sgfa, &
                                   lmax=la_max, &
                                   lmin=la_min, &
                                   nco_sum=ncoa_sum, &
                                   maxso=maxsoa, &
                                   npgf=npgfa, &
                                   nset=nseta, &
                                   nsgf=nsgfa, &
                                   nsgf_set=nsgf_seta, &
                                   pgf_radius=rpgfa, &
                                   set_radius=set_radius_a, &
                                   sphi=sphi_a, &
                                   zet=zeta)

            qs_kind => qs_kind_set(jkind)

            NULLIFY (paw_proj_b)
            CALL get_qs_kind(qs_kind=qs_kind, paw_proj_set=paw_proj_b, paw_atom=paw_atom_b)
            IF (.NOT. paw_atom_b) CYCLE

            CALL get_qs_kind(qs_kind=qs_kind, basis_set=orb_basis_paw, basis_type="GAPW_1C")
            IF (.NOT. ASSOCIATED(orb_basis_paw)) CYCLE
            CALL get_gto_basis_set(gto_basis_set=orb_basis_paw, maxso=maxsob, nset=nsetb)

            CALL get_paw_proj_set(paw_proj_set=paw_proj_b, &
                                  maxl=mlprj, &
                                  nprj=nprjla, &
                                  nsatbas=nsatbas, &
                                  rcprj=rcprj, &
                                  first_prj=fp_car, &
                                  first_prjs=fp_sph, &
                                  rzetprj=rzetprj, &
                                  zetprj=zetb)

            nsobtot = nsatbas

            clist => intac(iac)%alist(ilist)%clist(jneighbor)
            clist%catom = atom_b
            clist%cell = cell_b
            clist%rac = rab
            clist%nsgf_cnt = 0
            clist%maxac = 0.0_dp
            clist%maxach = 0.0_dp
            NULLIFY (clist%acint, clist%achint, clist%sgf_list)

            ALLOCATE (sgf_list(nsgfa))

            at_a => qs_kind_set(jkind)
            at_b => qs_kind_set(ikind)

            local = (atom_a == atom_b .AND. ALL(cell_b == 0))

            IF (local) THEN
               DO i = 1, maxder
                  ALLOCATE (oceh(i)%block(nsobtot, nsgfa), oces(i)%block(nsobtot, nsgfa))
                  oceh(i)%block = 0._dp
                  oces(i)%block = 0._dp
               END DO
               CALL build_oce_block_local(oceh, oces, at_a, sgf_list, nsgf_cnt)
               clist%nsgf_cnt = nsgf_cnt
               clist%sgf_soft_only = .FALSE.
               IF (nsgf_cnt > 0) THEN
                  ALLOCATE (clist%acint(nsgf_cnt, nsobtot, maxder), clist%sgf_list(nsgf_cnt))
                  clist%acint(:, :, :) = 0._dp
                  clist%sgf_list(:) = HUGE(0)
                  CPASSERT(nsgf_cnt == nsgfa)
                  ! *** Special case: A=B
                  ALLOCATE (clist%achint(nsgfa, nsobtot, maxder))
                  clist%achint = 0._dp
                  clist%acint(1:nsgfa, 1:nsobtot, 1) = TRANSPOSE(oces(1)%block(1:nsobtot, 1:nsgfa))
                  clist%achint(1:nsgfa, 1:nsobtot, 1) = TRANSPOSE(oceh(1)%block(1:nsobtot, 1:nsgfa))
                  clist%maxac = MAXVAL(ABS(clist%acint(:, :, 1)))
                  clist%maxach = 0._dp
                  clist%sgf_list(1:nsgf_cnt) = sgf_list(1:nsgf_cnt)
               END IF
               DO i = 1, maxder
                  DEALLOCATE (oceh(i)%block, oces(i)%block)
               END DO
            ELSE
               DO i = 1, maxder
                  ALLOCATE (oces(i)%block(nsobtot, nsgfa))
                  oces(i)%block = 0._dp
               END DO
               CALL build_oce_block(oces, at_a, at_b, rab, nder, sgf_list, nsgf_cnt, sgf_soft_only, eps_fit)
               clist%nsgf_cnt = nsgf_cnt
               clist%sgf_soft_only = sgf_soft_only
               IF (nsgf_cnt > 0) THEN
                  ALLOCATE (clist%acint(nsgf_cnt, nsobtot, maxder), clist%sgf_list(nsgf_cnt))
                  clist%acint(:, :, :) = 0._dp
                  clist%sgf_list(:) = HUGE(0)
                  DO i = 1, maxder
                     clist%acint(1:nsgf_cnt, 1:nsobtot, i) = TRANSPOSE(oces(i)%block(1:nsobtot, 1:nsgf_cnt))
                  END DO
                  clist%maxac = MAXVAL(ABS(clist%acint(:, :, 1)))
                  clist%maxach = 0._dp
                  clist%sgf_list(1:nsgf_cnt) = sgf_list(1:nsgf_cnt)
               END IF
               DO i = 1, maxder
                  DEALLOCATE (oces(i)%block)
               END DO
            END IF

            DEALLOCATE (sgf_list)

         END DO

         DEALLOCATE (sab, work, ai_work)
         DEALLOCATE (oceh, oces)
!$OMP END PARALLEL

         ! Setup sort index
         CALL sap_sort(intac)

      END IF

      CALL timestop(handle)

   END SUBROUTINE build_oce_matrices

! **************************************************************************************************
!> \brief Project a matrix block onto the local atomic functions.
!>
!>        Accumulates, separately for the hard and the soft matrices,
!>           proj_h(1:nso, 1:nso) += h_a(1:na, 1:nso)^T * (fac*blk(1:na, 1:nb)) * h_b(1:nb, 1:nso)
!>           proj_s(1:nso, 1:nso) += s_a(1:na, 1:nso)^T * (fac*blk(1:na, 1:nb)) * s_b(1:nb, 1:nso)
!>        The routine returns immediately if any of na, nb or nso is zero.
!>
!> \param h_a hard matrix of the first (row) side, na x nso
!> \param s_a soft matrix of the first (row) side, na x nso
!> \param na number of rows of h_a/s_a and of blk
!> \param h_b hard matrix of the second (column) side, nb x nso
!> \param s_b soft matrix of the second (column) side, nb x nso
!> \param nb number of rows of h_b/s_b and number of columns of blk
!> \param blk the na x nb block to be projected, stored with leading dimension ldb
!> \param ldb leading dimension of blk
!> \param proj_h hard projection (nso x nso), updated in place
!> \param proj_s soft projection (nso x nso), updated in place
!> \param nso number of local atomic functions, i.e. the size of the projections
!> \param len1 length of scratch buffer buf1; must be at least na*nso
!> \param len2 length of scratch buffer buf2; must be at least nso*nso when distab is .TRUE.
!> \param fac scalar prefactor applied to blk
!> \param distab if .TRUE., each nso x nso product is first formed in buf2 and then added to
!>        proj_h/proj_s with daxpy; otherwise dgemm accumulates directly (beta = 1).
!>        Both paths add the same contribution.
!> \par History
!>      02.2009 created
!>      09.2016 use automatic arrays [M Tucker]
! **************************************************************************************************
   SUBROUTINE proj_blk(h_a, s_a, na, h_b, s_b, nb, blk, ldb, proj_h, proj_s, nso, len1, len2, fac, distab)

      INTEGER, INTENT(IN)                                :: na
      REAL(KIND=dp), INTENT(IN)                          :: s_a(na, *), h_a(na, *)
      INTEGER, INTENT(IN)                                :: nb
      REAL(KIND=dp), INTENT(IN)                          :: s_b(nb, *), h_b(nb, *)
      INTEGER, INTENT(IN)                                :: ldb
      REAL(KIND=dp), INTENT(IN)                          :: blk(ldb, *)
      INTEGER, INTENT(IN)                                :: nso
      REAL(KIND=dp), INTENT(INOUT)                       :: proj_s(nso, *), proj_h(nso, *)
      INTEGER, INTENT(IN)                                :: len1, len2
      REAL(KIND=dp), INTENT(IN)                          :: fac
      LOGICAL, INTENT(IN)                                :: distab

      ! buf1 holds fac*blk*x_b (na x nso); buf2 holds x_a^T*buf1 (nso x nso, distab only)
      REAL(KIND=dp)                                      :: buf1(len1), buf2(len2)

      ! Empty blocks contribute nothing
      IF (na == 0 .OR. nb == 0 .OR. nso == 0) RETURN

      IF (na == 1 .AND. nb == 1) THEN
         ! A 1 x 1 block reduces the triple product to a rank-1 update:
         ! proj += fac*blk(1,1) * x_a(1,:)^T * x_b(1,:)
         ! hard
         CALL dger(nso, nso, fac*blk(1, 1), h_a(1, 1), 1, h_b(1, 1), 1, proj_h(1, 1), nso)
         ! soft
         CALL dger(nso, nso, fac*blk(1, 1), s_a(1, 1), 1, s_b(1, 1), 1, proj_s(1, 1), nso)
      ELSE IF (distab) THEN
         ! hard: buf1 = fac*blk*h_b, buf2 = h_a^T*buf1, proj_h += buf2
         CALL dgemm('N', 'N', na, nso, nb, fac, blk(1, 1), ldb, h_b(1, 1), nb, 0.0_dp, buf1(1), na)
         CALL dgemm('T', 'N', nso, nso, na, 1.0_dp, h_a(1, 1), na, buf1(1), na, 0.0_dp, buf2(1), nso)
         CALL daxpy(nso*nso, 1.0_dp, buf2(1), 1, proj_h(1, 1), 1)
         ! soft: buf1/buf2 must be rebuilt from s_a/s_b; at this point buf2 still holds the
         ! hard product, which must not be added to proj_s
         CALL dgemm('N', 'N', na, nso, nb, fac, blk(1, 1), ldb, s_b(1, 1), nb, 0.0_dp, buf1(1), na)
         CALL dgemm('T', 'N', nso, nso, na, 1.0_dp, s_a(1, 1), na, buf1(1), na, 0.0_dp, buf2(1), nso)
         CALL daxpy(nso*nso, 1.0_dp, buf2(1), 1, proj_s(1, 1), 1)
      ELSE
         ! hard: buf1 = fac*blk*h_b, proj_h += h_a^T*buf1
         CALL dgemm('N', 'N', na, nso, nb, fac, blk(1, 1), ldb, h_b(1, 1), nb, 0.0_dp, buf1(1), na)
         CALL dgemm('T', 'N', nso, nso, na, 1.0_dp, h_a(1, 1), na, buf1(1), na, 1.0_dp, proj_h(1, 1), nso)
         ! soft: buf1 = fac*blk*s_b, proj_s += s_a^T*buf1
         CALL dgemm('N', 'N', na, nso, nb, fac, blk(1, 1), ldb, s_b(1, 1), nb, 0.0_dp, buf1(1), na)
         CALL dgemm('T', 'N', nso, nso, na, 1.0_dp, s_a(1, 1), na, buf1(1), na, 1.0_dp, proj_s(1, 1), nso)
      END IF

   END SUBROUTINE proj_blk

! **************************************************************************************************
!> \brief Copies a square block from the old (uncompressed) projector indexing into the new
!>        compressed indexing defined by the GAPW_1C basis of the given kind.
!>        On return aout(j, i) = ain(n2oindex(j), n2oindex(i)) for i, j = 1..nsatbas.
!> \param ain matrix in old indexing
!> \param aout matrix in new compressed indexing; only its leading nsatbas x nsatbas part is written
!> \param atom kind whose GAPW_1C basis supplies the new-to-old index map and nsatbas
! **************************************************************************************************
   SUBROUTINE prj_gather(ain, aout, atom)

      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: ain
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: aout
      TYPE(qs_kind_type), INTENT(IN)                     :: atom

      INTEGER                                            :: nsat
      INTEGER, DIMENSION(:), POINTER                     :: new2old
      TYPE(gto_basis_set_type), POINTER                  :: basis_1c

      NULLIFY (basis_1c, new2old)
      CALL get_qs_kind(qs_kind=atom, basis_set=basis_1c, basis_type="GAPW_1C")
      CALL get_paw_basis_info(basis_1c, n2oindex=new2old, nsatbas=nsat)

      ! Vector subscripts on the right-hand side pick out the old-index rows/columns in the
      ! new order; repeated entries would be harmless here since ain is only read.
      aout(1:nsat, 1:nsat) = ain(new2old(1:nsat), new2old(1:nsat))

      ! The index map obtained above is released here by this routine
      DEALLOCATE (new2old)

   END SUBROUTINE prj_gather

! **************************************************************************************************
!> \brief Adds a square block given in the new compressed indexing (GAPW_1C basis of the given
!>        kind) back into a matrix in the old (uncompressed) indexing.
!>        For i, j = 1..nsatbas: aout(n2oindex(j), n2oindex(i)) += ain(j, i).
!> \param ain  matrix in new compressed indexing
!> \param aout matrix in old indexing; contributions are added to its existing contents
!> \param atom kind whose GAPW_1C basis supplies the new-to-old index map and nsatbas
! **************************************************************************************************
   SUBROUTINE prj_scatter(ain, aout, atom)

      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: ain
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: aout
      TYPE(qs_kind_type), INTENT(IN)                     :: atom

      INTEGER                                            :: icol, irow, nsat, ocol, orow
      INTEGER, DIMENSION(:), POINTER                     :: new2old
      TYPE(gto_basis_set_type), POINTER                  :: basis_1c

      NULLIFY (basis_1c, new2old)
      CALL get_qs_kind(qs_kind=atom, basis_set=basis_1c, basis_type="GAPW_1C")
      CALL get_paw_basis_info(basis_1c, n2oindex=new2old, nsatbas=nsat)

      ! An explicit loop rather than a vector-subscripted left-hand side keeps the
      ! accumulation well defined even if the index map contained repeated entries.
      DO icol = 1, nsat
         ocol = new2old(icol)
         DO irow = 1, nsat
            orow = new2old(irow)
            aout(orow, ocol) = aout(orow, ocol) + ain(irow, icol)
         END DO
      END DO

      ! The index map obtained above is released here by this routine
      DEALLOCATE (new2old)

   END SUBROUTINE prj_scatter

END MODULE qs_oce_methods
